{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ozc9I4X1C4nX"
      },
      "source": [
        "Copyright 2022 Google LLC.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "aIZExrKYCqRi"
      },
      "outputs": [],
      "source": [
        "#@title License\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4mdq2mi2Y7_N"
      },
      "source": [
        "# Shortcut testing with synthetic images\n",
        "\n",
        "We illustrate the method proposed in Brown et al. 2022 (https://arxiv.org/abs/2207.10384) to detect model shortcutting, i.e. when a model relies both on signals related to the label Y and on signals related to an auxiliary/sensitive attribute A.\n",
        "\n",
        "To demonstrate ShorT, we need a dataset in which the main (Y) and auxiliary (A) tasks are not too easy and \"help\" each other. We use the MNIST dataset, with the aim of discriminating between digits smaller than 5 and digits greater than or equal to 5. We add a confounder to each image, i.e. a small colored square. The color of the square (red or green) can then be correlated with the label to create a spurious signal. We add significant amounts of noise to both signals to avoid a \"binary\" behavior of the model (i.e. when it fully relies on one signal or the other, but never on both).\n",
        "\n",
        "We demonstrate how ShorT identifies the confounding, when present."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-UcYpkXxZApt"
      },
      "source": [
        "# Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "z9NNjxN4XXD3"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' \n",
        "\n",
        "import importlib\n",
        "import random\n",
        "\n",
        "from typing import *\n",
        "\n",
        "import attr\n",
        "import matplotlib as mpl\n",
        "import matplotlib.pyplot as plt\n",
        "import math\n",
        "import numpy as np\n",
        "import seaborn as sns\n",
        "import scipy\n",
        "\n",
        "from matplotlib.cm import get_cmap\n",
        "from matplotlib import gridspec\n",
        "from matplotlib import cm\n",
        "\n",
        "import sklearn\n",
        "from sklearn.utils import shuffle\n",
        "\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "from tensorflow.keras import backend as K\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LGs2__K3Y4ew"
      },
      "outputs": [],
      "source": [
        "!pip install ml_collections"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SUFJth7Qby8R"
      },
      "outputs": [],
      "source": [
        "import ml_collections as mlc"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Hsk4c0jwcLa8"
      },
      "source": [
        "# Synthetic dataset from MNIST\n",
        "Images are generated by adding a red or green square to an MNIST image ([Deng et al., 2012](https://ieeexplore.ieee.org/document/6296535)). Images showing a digit smaller than 5 get `label=0` and images showing a digit greater than or equal to 5 get `label=1`. Two parameters, `corr_a_y0` and `corr_a_y1`, specify the percentage of images of each label whose square is red. A value above 50% means that images with that label mostly have a red square, while a value below 50% means that they mostly have a green square; the larger the gap between the two parameters, the stronger the spurious correlation between the square color and the label.\n",
        "\n",
        "All data generation parameters are specified in the config dictionary (`cfg_data`). It is important to check that the two tasks (predicting the label and predicting the attribute) are neither trivial nor too simple."
      ]
    },
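    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The class-conditional color sampling in `create_images_dataset` (one `np.random.choice(n_colors, p=conprob_colors_orie[i])` call per image) can be sketched in vectorized form. This is an illustrative sketch rather than the notebook's code: it assumes NumPy's `default_rng` Generator and the default percentages from `cfg_data`, with color index 0 for red and 1 for green.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Turn the two percentages into P(color | class); column 0 is red, column 1 green.\n",
        "corr_a_y0, corr_a_y1 = 20, 95\n",
        "p_color_given_y = np.array([\n",
        "    [corr_a_y0 / 100, 1 - corr_a_y0 / 100],  # class 0\n",
        "    [corr_a_y1 / 100, 1 - corr_a_y1 / 100],  # class 1\n",
        "])\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "n_per_class = 5000\n",
        "# Sample one color per image, conditioned on the image's class.\n",
        "colors = {y: rng.choice(2, size=n_per_class, p=p_color_given_y[y]) for y in (0, 1)}\n",
        "red_rate = {y: float(np.mean(colors[y] == 0)) for y in (0, 1)}\n",
        "print(red_rate)  # empirical rates, close to 0.20 and 0.95\n",
        "```\n",
        "\n",
        "Only the gap between the two percentages makes the square color informative about the label: setting both percentages to the same value would make color and label independent."
      ]
    },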
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "eZZr6SzLdHdl"
      },
      "outputs": [],
      "source": [
        "#@title Data utils\n",
        "\n",
        "def add_square_to_image(image, color, square_size, noise_thresh = 0.1):\n",
        "  image_size = image.shape\n",
        "  start_x = np.random.choice(image_size[0]-square_size)\n",
        "  end_x = min(image_size[0], start_x+square_size)\n",
        "  start_y = np.random.choice(image_size[1]-square_size)\n",
        "  end_y = min(image_size[1], start_y+square_size)\n",
        "  square = np.random.uniform(\n",
        "      size=(square_size, square_size, image_size[-1])) \u003e= noise_thresh\n",
        "  image[start_x:end_x, start_y:end_y, :] = square * np.array(color) #color\n",
        "  return image, square, start_x, start_y\n",
        "\n",
        "def create_images_dataset(corr_a_y0, corr_a_y1, n_images_per_class, n_labels,\n",
        "                          select_labels, n_colors, colors, color_map,\n",
        "                          noise_thresh, conf_noise=0.1, mnist_split=\"train\"):\n",
        "\n",
        "  cls0_corr = float(corr_a_y0) / 100\n",
        "  cls1_corr = float(corr_a_y1) / 100\n",
        "\n",
        "  conprob_colors_orie = np.array([[cls0_corr, 1 - cls0_corr],\n",
        "                                  [cls1_corr, 1 - cls1_corr]])\n",
        "  \n",
        "  \n",
        "  # Load MNIST\n",
        "  (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n",
        "  if mnist_split == \"train\":\n",
        "    images = train_images\n",
        "    labels = train_labels\n",
        "  else:\n",
        "    images = test_images\n",
        "    labels = test_labels\n",
        "  \n",
        "  n_images = n_images_per_class * n_labels\n",
        "  n_pixels = images[0].shape[0]\n",
        "  images_colored = np.zeros((n_images, n_pixels, n_pixels, 3))\n",
        "  flip_images = np.zeros((n_images, n_pixels, n_pixels, 3))\n",
        "  image_labels = np.zeros((n_images, 1), dtype=int)\n",
        "  color_labels = np.zeros((n_images, 1), dtype=int)\n",
        "\n",
        "  for i, label in enumerate(select_labels):\n",
        "    selected = np.isin(labels, label)  # digits belonging to this class group\n",
        "    selected_label = labels[selected]\n",
        "    selected_images = images[selected, :,:]\n",
        "    for class_image_index in range(n_images_per_class):\n",
        "      image_index = i * n_images_per_class + class_image_index\n",
        "      color = np.random.choice(n_colors, p=conprob_colors_orie[i])\n",
        "      flip_color = 1 if color == 0 else 0\n",
        "      # randomly select image from MNIST\n",
        "      idx = np.random.choice(selected_label.shape[0], replace=True)\n",
        "      image = np.reshape(selected_images[idx,:,:], (n_pixels,n_pixels, 1))\n",
        "      channels = np.concatenate((image, image, image), axis=2)\n",
        "      # Add colored square to the image\n",
        "      images_colored[image_index], square, x, y = add_square_to_image(channels/255,\n",
        "                                                        color_map[colors[color]],\n",
        "                                                        int(n_pixels/7),\n",
        "                                                        noise_thresh=conf_noise)\n",
        "      flip_images[image_index] = channels/255\n",
        "      sq_szx, sq_szy, _ = square.shape\n",
        "      flip_images[image_index,x:x+sq_szx,y:y+sq_szy, :] = square * color_map[colors[flip_color]]\n",
        "\n",
        "      noise_img = np.random.uniform(\n",
        "          size=[n_pixels, n_pixels, 1]) \u003c noise_thresh  # white noise\n",
        "      noise_img = np.concatenate((noise_img, noise_img, noise_img), axis=2)\n",
        "      noise_img =  noise_img.astype('float64')\n",
        "      noise_img[x:x+sq_szx,y:y+sq_szy, :] = square * (0.,0.,0.) # remove noise from added square\n",
        "\n",
        "      images_colored[image_index] = np.clip(\n",
        "          images_colored[image_index] + noise_img, 0.0, 1.0)\n",
        "      flip_images[image_index] = np.clip(\n",
        "          flip_images[image_index] + noise_img, 0.0, 1.0)\n",
        "\n",
        "      image_labels[image_index,:] = i\n",
        "      color_labels[image_index,:] = color\n",
        "\n",
        "  return images_colored, image_labels, color_labels, flip_images\n",
        "\n",
        "\n",
        "def gen_dataset(cfg_data):\n",
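        "  \"\"\"Generates the colored MNIST dataset.\n",
        "\n",
        "  With the default select_labels, each image shows an MNIST digit labeled\n",
        "  0 (digit \u003c 5) or 1 (digit \u003e= 5), with a small square that is either red\n",
        "  (color_label 0) or green (color_label 1). Returns a dict with a single\n",
        "  shuffled \"train\" split that also holds the counterfactual images\n",
        "  (\"flip_image\") in which the square color is flipped.\n",
        "  \"\"\"\n",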
        "\n",
        "  tr_im, tr_lab, tr_col_lab, tr_fl_im = create_images_dataset(**cfg_data)\n",
        "  \n",
        "  train_images, train_labels, train_color_labels, train_flip_images = \\\n",
        "  sklearn.utils.shuffle(tr_im, tr_lab, tr_col_lab, tr_fl_im)\n",
        "\n",
        "  dataset = dict(\n",
        "      train=dict(\n",
        "          image=train_images,\n",
        "          label=train_labels,\n",
        "          color_label=train_color_labels,\n",
        "          flip_image=train_flip_images),\n",
        "  )\n",
        "  return dataset"
      ]
    },
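    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The two noise levels act differently. In `add_square_to_image`, each pixel and channel of the square keeps its color only when a uniform draw is `\u003e= noise_thresh` (called with `conf_noise`), so roughly a fraction `1 - conf_noise` of the square is colored. The background noise in `create_images_dataset` turns a pixel white when a uniform draw is `\u003c noise_thresh`; it is then removed inside the square, and the result is clipped to [0, 1]. The standalone sketch below only estimates these two fractions for the default values with a Monte Carlo draw; it is not part of the data pipeline.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Monte Carlo estimate of the two noise fractions (illustrative sketch).\n",
        "rng = np.random.default_rng(0)\n",
        "conf_noise, noise_thresh = 0.6, 0.5  # defaults from cfg_data\n",
        "square_size, n_pixels = 4, 28        # int(28 / 7) == 4 for MNIST\n",
        "trials = 2000\n",
        "\n",
        "# Square pixels that keep their color: uniform \u003e= conf_noise.\n",
        "square = rng.uniform(size=(trials, square_size, square_size)) \u003e= conf_noise\n",
        "colored_fraction = float(square.mean())  # expected: 1 - conf_noise = 0.4\n",
        "\n",
        "# Pixels turned white by the background noise: uniform \u003c noise_thresh.\n",
        "noise = rng.uniform(size=(trials, n_pixels, n_pixels)) \u003c noise_thresh\n",
        "white_fraction = float(noise.mean())     # expected: noise_thresh = 0.5\n",
        "\n",
        "print(round(colored_fraction, 2), round(white_fraction, 2))\n",
        "```\n",
        "\n",
        "With the defaults, only about 40% of the square is colored and about half of the image pixels are whitened, so neither the confounder nor the digit is clean."
      ]
    },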
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A6qfNB64dU5s"
      },
      "outputs": [],
      "source": [
        "#@title Synthetic data parameters\n",
        "\n",
        "cfg_data = mlc.ConfigDict()\n",
        "\n",
        "# These must show a large difference to encourage shortcutting\n",
        "cfg_data.corr_a_y0 = 20  #@param Proportion of images with label 0 being red\n",
        "cfg_data.corr_a_y1 = 95  #@param Proportion of images with label 1 being red\n",
        "\n",
        "# Note: asymmetrical, high correlations (i.e. corr_a_y0 + corr_a_y1 != 100) lead to\n",
        "# unfairness in terms of equalized odds\n",
        "\n",
        "cfg_data.n_images_per_class = 5000  #@param\n",
        "\n",
        "cfg_data.noise_thresh = 0.5  # noise added to the number\n",
        "cfg_data.conf_noise = 0.6    # noise added to the confounder\n",
        "\n",
        "cfg_data.colors = ['red', 'green', 'blue']\n",
        "cfg_data.color_map = mlc.ConfigDict()\n",
        "cfg_data.color_map.red = (1., 0., 0.)\n",
        "cfg_data.color_map.green = (0., 1., 0.)\n",
        "cfg_data.color_map.blue = (0., 0., 1.)\n",
        "\n",
        "cfg_data.n_colors = 2\n",
        "cfg_data.n_labels = 2\n",
        "\n",
        "cfg_data.select_labels = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]\n",
        "label_names = ['\u003c5', '\u003e=5']\n",
        "cfg_data.mnist_split = \"train\"\n",
        "\n",
        "# set the seed\n",
        "seed = 123\n",
        "tf.random.set_seed(seed)\n",
        "np.random.seed(seed)"
      ]
    },
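    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough guide to how strong the shortcut is with these settings, the worked example below applies Bayes' rule to the default percentages. It assumes balanced classes (the same `n_images_per_class` for both labels) and a hypothetical model that reads the square color perfectly and ignores the digit, so the resulting accuracy is an optimistic estimate given the noise on the square.\n",
        "\n",
        "```python\n",
        "# Worked example with the default cfg_data percentages (balanced classes assumed).\n",
        "corr_a_y0, corr_a_y1 = 20, 95\n",
        "p_red_y0 = corr_a_y0 / 100  # P(red | Y=0)\n",
        "p_red_y1 = corr_a_y1 / 100  # P(red | Y=1)\n",
        "p_y1 = 0.5                  # P(Y=1) with balanced classes\n",
        "\n",
        "# Bayes' rule: how much the color alone says about the label.\n",
        "p_red = p_y1 * p_red_y1 + (1 - p_y1) * p_red_y0\n",
        "p_y1_given_red = p_y1 * p_red_y1 / p_red\n",
        "p_y1_given_green = p_y1 * (1 - p_red_y1) / (1 - p_red)\n",
        "\n",
        "# Rule 'predict 1 if red, else 0': right on red class-1 and green class-0 images.\n",
        "color_only_accuracy = p_y1 * p_red_y1 + (1 - p_y1) * (1 - p_red_y0)\n",
        "print(round(p_y1_given_red, 3), round(p_y1_given_green, 3), round(color_only_accuracy, 3))\n",
        "# prints 0.826 0.059 0.875\n",
        "```\n",
        "\n",
        "Under these assumptions, the square color alone could yield about 87.5% accuracy, which is why the model has an incentive to use it as a shortcut."
      ]
    },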
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wwElA1s9ebiC"
      },
      "source": [
        "Based on these parameters, we generate training, validation and test data. In addition, we generate \"counterfactual\" images (stored under the `flip_image` key) that show the same digit, square position and background noise as the original image, but with the color of the square flipped."
      ]
    },
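    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The sketch below mirrors how an image and its counterfactual are built in `create_images_dataset`: both start from the same digit and place the square at the same position with the same random mask (in the notebook they also receive the same background noise). A blank image and a fixed square position are illustrative stand-ins for a real MNIST digit and the random position. Because the mask has one random draw per color channel, the red and green squares light up different random subsets of pixels with the same expected density, while all differences between the pair stay inside the square.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "red, green = np.array([1., 0., 0.]), np.array([0., 1., 0.])\n",
        "n_pixels, size, x, y = 28, 4, 10, 12   # illustrative square position\n",
        "\n",
        "digit = np.zeros((n_pixels, n_pixels, 3))          # stand-in for an MNIST digit\n",
        "mask = rng.uniform(size=(size, size, 3)) \u003e= 0.6    # conf_noise = 0.6, one draw per channel\n",
        "\n",
        "original = digit.copy()\n",
        "original[x:x + size, y:y + size, :] = mask * red   # sampled color: red\n",
        "flipped = digit.copy()\n",
        "flipped[x:x + size, y:y + size, :] = mask * green  # counterfactual: green\n",
        "\n",
        "# Locate the pixels that differ between the pair.\n",
        "changed = np.any(original != flipped, axis=-1)\n",
        "outside = changed.copy()\n",
        "outside[x:x + size, y:y + size] = False\n",
        "print(int(changed.sum()), int(outside.sum()))      # outside count is 0\n",
        "```"
      ]
    },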
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yayPrkNueTCD"
      },
      "outputs": [],
      "source": [
        "#@title Train and validation data\n",
        "dataset = gen_dataset(cfg_data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iz3Y5yBle7TR"
      },
      "outputs": [],
      "source": [
        "#@markdown Explore the dataset\n",
        "dataset[\"train\"].keys()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Isq1LxhTfA5x"
      },
      "outputs": [],
      "source": [
        "#@markdown Compute proportions of the data to check specifications\n",
        "# The printed proportions should be close to corr_a_y0 / 100 and corr_a_y1 / 100\n",
        "\n",
        "n_r1 = np.sum(np.logical_and(dataset[\"train\"][\"label\"] == 0, \n",
        "                             dataset[\"train\"][\"color_label\"] == 0)) / np.sum(dataset[\"train\"][\"label\"] == 0)\n",
        "print(\"Proportion of 0 labels that are red: {}\".format(n_r1))\n",
        "\n",
        "n_r1 = np.sum(np.logical_and(dataset[\"train\"][\"label\"] == 1,\n",
        "                             dataset[\"train\"][\"color_label\"] == 0)) / np.sum(dataset[\"train\"][\"label\"] == 1)\n",
        "print(\"Proportion of 1 labels that are red: {}\".format(n_r1))"
      ]
    },
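    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The two proportions above can also be read off a 2 x 2 table of P(color | label). The helper below, `color_given_label` (a name introduced here for illustration), computes it with `np.add.at`; applied to `dataset[\"train\"][\"label\"]` and `dataset[\"train\"][\"color_label\"]`, its first column should approximate `corr_a_y0 / 100` and `corr_a_y1 / 100`.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def color_given_label(labels, color_labels, n_labels=2, n_colors=2):\n",
        "    \"\"\"Return P(color | label) as an (n_labels, n_colors) array.\"\"\"\n",
        "    labels = np.ravel(labels).astype(int)\n",
        "    color_labels = np.ravel(color_labels).astype(int)\n",
        "    counts = np.zeros((n_labels, n_colors))\n",
        "    np.add.at(counts, (labels, color_labels), 1)  # count each (label, color) pair\n",
        "    return counts / counts.sum(axis=1, keepdims=True)\n",
        "\n",
        "# Toy example with arrays of shape (n, 1), as in dataset['train'].\n",
        "labels = np.array([[0], [0], [0], [0], [1], [1], [1], [1]])\n",
        "colors = np.array([[0], [1], [1], [1], [0], [0], [0], [1]])\n",
        "print(color_given_label(labels, colors))  # rows: [0.25 0.75] and [0.75 0.25]\n",
        "```"
      ]
    },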
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BeLWotfXfFuZ"
      },
      "outputs": [],
      "source": [
        "#@markdown Visualize the data (5 random samples)\n",
        "\n",
        "split = \"train\"\n",
        "NUM_TEST_SAMPLES = len(dataset[split][\"label\"])\n",
        "print(\"Number of samples:\", NUM_TEST_SAMPLES)\n",
        "\n",
        "# Show 5 distinct random images with their label and square color\n",
        "for idx in random.sample(range(NUM_TEST_SAMPLES), k=5):\n",
        "  label = int(dataset[split][\"label\"][idx, 0])\n",
        "  color = cfg_data.colors[int(dataset[split][\"color_label\"][idx, 0])]\n",
        "  plt.figure()\n",
        "  plt.title(f\"label: {label_names[label]}, square: {color}\")\n",
        "  plt.imshow(dataset[split][\"image\"][idx, ...])\n",
        "  plt.show()\n",
        "n_pixels, n_pixels, n_channels = dataset[split][\"image\"][idx, ...].shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A7XPU8_ufpbw"
      },
      "outputs": [],
      "source": [
        "#@title Test data\n",
        "\n",
        "cfg_data.n_images_per_class = 2000  #@param\n",
        "cfg_data.mnist_split = \"test\"  # We select images from the \"test\" MNIST split\n",
        "\n",
        "test_dataset = gen_dataset(cfg_data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qQ_ZhHVHf9-W"
      },
      "outputs": [],
      "source": [
        "#@title Transform data into TF Datasets\n",
        "def tf_data(dataset, batch_size = 64, valid_prop = 0.2, is_train = True):\n",
        "\n",
        "  # Original images\n",
        "  split = \"train\"\n",
        "  training_data = tf.data.Dataset.from_tensor_slices((dataset[split][\"image\"],\n",
        "                              dataset[split][\"label\"],\n",
        "                              dataset[split][\"color_label\"]))\n",
        "  n_images = dataset[split][\"label\"].shape[0]\n",
        "  n_valid = math.floor(n_images * valid_prop)\n",
        "  valid_data = training_data.take(n_valid).cache().repeat().batch(batch_size)\n",
        "  ds = training_data.skip(n_valid).take(n_images - n_valid).cache()\n",
        "  # Flipped color on the same images, with the same train/validation split\n",
        "  data_flipped = tf.data.Dataset.from_tensor_slices((dataset[split][\"flip_image\"],\n",
        "                              dataset[split][\"label\"],\n",
        "                              np.logical_not(dataset[split][\"color_label\"]).astype(float)))\n",
        "  data_flipped = data_flipped.skip(n_valid)\n",
        "  \n",
        "  if is_train:\n",
        "    # Cannot be used as pairs due to shuffling\n",
        "    data = ds.repeat().shuffle(10000).batch(batch_size)\n",
        "    data_flipped = data_flipped.repeat().shuffle(10000).batch(batch_size)\n",
        "  else:\n",
        "    # Can be used for paired comparisons\n",
        "    data = ds.repeat().batch(batch_size)\n",
        "    data_flipped = data_flipped.repeat().batch(batch_size)\n",
        "  \n",
        "  return data, data_flipped, valid_data"
      ]
    },
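    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The comment in `tf_data` notes that shuffled data cannot be used as pairs: calling `shuffle` separately on the original and flipped datasets applies two independent permutations. The NumPy sketch below, on toy arrays standing in for `image` and `flip_image`, shows the issue and how a single shared permutation keeps each image aligned with its counterfactual. In `tf.data`, one option would be to combine the two datasets with `tf.data.Dataset.zip` before shuffling; that variant is not used in this notebook.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "ids = np.arange(10)\n",
        "images = ids * 10           # item i of both arrays comes from the same digit\n",
        "flip_images = ids * 10 + 1\n",
        "\n",
        "# Independent shuffles, as with two separate .shuffle() calls.\n",
        "a = rng.permutation(images)\n",
        "b = rng.permutation(flip_images)\n",
        "independent_pairs_ok = bool(np.all(a // 10 == b // 10))\n",
        "\n",
        "# One shared permutation applied to both arrays keeps the pairs aligned.\n",
        "perm = rng.permutation(len(ids))\n",
        "shared_pairs_ok = bool(np.all(images[perm] // 10 == flip_images[perm] // 10))\n",
        "print(independent_pairs_ok, shared_pairs_ok)\n",
        "```"
      ]
    },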
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ma-tMJPogEwe"
      },
      "outputs": [],
      "source": [
        "pos = mlc.ConfigDict()\n",
        "pos.x, pos.y, pos.a = 0, 1, 2  # Index of the image (x), label (y) and color attribute (a) in each dataset element\n",
        "\n",
        "# Train and validation data\n",
        "train_data, _, valid_data = tf_data(dataset, 64)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kZChfj7igeMd"
      },
      "outputs": [],
      "source": [
        "# Test data: not shuffled, so that each image can be compared with its counterfactual\n",
        "test_data, test_data_flipped, _ = tf_data(test_dataset, 64, valid_prop=0.0, is_train=False)"
      ]
    },
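    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Keeping the test data unshuffled means that batch i of `test_data` and batch i of `test_data_flipped` hold the same images, so predictions can be compared image by image. The helper below (`counterfactual_shift`, a hypothetical name, not part of ShorT) sketches one such comparison: the mean change in the predicted probability of Y = 1 when the square color is flipped, split by the original color. The original color labels should come from `test_data`, since `test_data_flipped` stores the flipped ones. The toy probabilities are made up for illustration.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def counterfactual_shift(p_orig, p_flip, color_label):\n",
        "    \"\"\"Mean change in P(Y=1) when the square color is flipped, by original color.\"\"\"\n",
        "    p_orig, p_flip = np.ravel(p_orig), np.ravel(p_flip)\n",
        "    color_label = np.ravel(color_label)\n",
        "    delta = p_flip - p_orig  # positive: flipping the color raises P(Y=1)\n",
        "    return {\n",
        "        'red_to_green': float(delta[color_label == 0].mean()),\n",
        "        'green_to_red': float(delta[color_label == 1].mean()),\n",
        "    }\n",
        "\n",
        "# Made-up probabilities for 4 paired test images (color 0 = red, 1 = green).\n",
        "p_orig = np.array([0.9, 0.8, 0.3, 0.2])\n",
        "p_flip = np.array([0.6, 0.5, 0.6, 0.5])\n",
        "color = np.array([0, 0, 1, 1])\n",
        "print(counterfactual_shift(p_orig, p_flip, color))  # about -0.3 and +0.3\n",
        "```\n",
        "\n",
        "Shifts close to zero would suggest that the predictions barely depend on the square color."
      ]
    },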
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xujPFv_tgv_R"
      },
      "source": [
        "# Models\n",
        "\n",
        "## Baseline model - SingleHead class\n",
        "We build a simple MLP that predicts one binary label. It can be used to assess the performance of the baseline model f(X) -\u003e Y.\n",
        "\n",
        "## Attribute encoding - SingleHead with frozen feature extractor\n",
        "In addition, this model can take a frozen feature extractor as an extra input. It then adds a linear layer on top of this feature extractor to perform the binary classification; only the weights of this layer are tuned during training. This architecture is used to assess how strongly the features of another model encode the attribute (attribute encoding).\n",
        "\n",
        "## Multi-task - MultiHead with Gradient Reversal\n",
        "The multi-task model first comprises a series of layers that act as a \"feature extractor\". On top of the feature extractor, we add two heads:\n",
        "- one head for the label, which is predicted from a single linear layer,\n",
        "- one head for an auxiliary task (here, predicting the color of the square). This head includes at least one non-linear hidden layer on top of the gradient scaling operation.\n",
        "\n",
        "The weight of the auxiliary loss in the total loss (`attr_loss_weight`) is a hyper-parameter of the method; the weight of the label loss is fixed to 1.0. A gradient reversal layer is placed at the base of the auxiliary head: it is the identity in the forward pass, and in the backward pass it multiplies the gradient flowing back into the feature extractor by a factor `attr_grad_updates` (an extra hyper-parameter). This factor can be positive or negative, which respectively encourages or discourages the encoding of the attribute in the shared features. Setting both hyper-parameters to 0 corresponds to the baseline single-head model.\n",
        "\n",
        "\n",
        "A config (`cfg`) dictionary specifies the parameters of the MLP architecture."
      ]
    },
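    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the role of the two hyper-parameters concrete, the sketch below computes by hand the gradient that reaches a shared feature-extractor weight. It assumes scalar weights, linear heads and squared-error losses (the notebook uses Dense layers and binary cross-entropy), and it models the gradient reversal layer as the identity in the forward pass and a multiplication by `hp_lambda` in the backward pass. The heads' own weights (not shown) still receive gradients that are not scaled by `hp_lambda`; only the gradient entering the shared features is.\n",
        "\n",
        "```python\n",
        "def shared_weight_grad(x, y, a, w_f, v_y, v_a, attr_loss_weight, hp_lambda):\n",
        "    \"\"\"Gradient of the total loss w.r.t. the shared weight w_f (scalar sketch).\"\"\"\n",
        "    z = w_f * x      # shared feature from the 'feature extractor'\n",
        "    y_hat = v_y * z  # label head\n",
        "    a_hat = v_a * z  # attribute head; the reversal layer is the identity here\n",
        "    # Total loss: (y_hat - y)**2 + attr_loss_weight * (a_hat - a)**2\n",
        "    dz_from_y = 2 * (y_hat - y) * v_y\n",
        "    dz_from_a = attr_loss_weight * 2 * (a_hat - a) * v_a\n",
        "    # Backward pass: the reversal layer multiplies the attribute gradient by hp_lambda.\n",
        "    dz = dz_from_y + hp_lambda * dz_from_a\n",
        "    return dz * x    # chain rule: dz / dw_f = x\n",
        "\n",
        "args = dict(x=1.0, y=1.0, a=0.0, w_f=0.5, v_y=1.0, v_a=1.0)\n",
        "for lam in (1.0, 0.0, -1.0):\n",
        "    print(lam, shared_weight_grad(attr_loss_weight=1.0, hp_lambda=lam, **args))\n",
        "# prints 1.0 0.0, then 0.0 -1.0, then -1.0 -2.0 (one pair per line)\n",
        "```\n",
        "\n",
        "With `hp_lambda = 0` (or `attr_loss_weight = 0`) the shared weight only sees the label gradient, as in the baseline; a positive value adds the attribute gradient and a negative value subtracts it."
      ]
    },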
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "A41cqIbNi_mJ"
      },
      "outputs": [],
      "source": [
        "#@title Model utils\n",
        "\n",
        "class GradientReversal(tf.keras.layers.Layer):\n",
        "\n",
        "    @tf.custom_gradient\n",
        "    def grad_reverse(self, x):\n",
        "        y = tf.identity(x)\n",
        "        def custom_grad(dy):\n",
        "          return self.hp_lambda * dy\n",
        "        return y, custom_grad\n",
        "\n",
        "    def __init__(self, hp_lambda, **kwargs):\n",
        "        super(GradientReversal, self).__init__(**kwargs)\n",
        "        self.hp_lambda = K.variable(hp_lambda, dtype='float', name='hp_lambda')\n",
        "\n",
        "    def call(self, x, mask=None):\n",
        "        return self.grad_reverse(x)\n",
        "\n",
        "    def set_hp_lambda(self,hp_lambda):\n",
        "        K.set_value(self.hp_lambda, hp_lambda)\n",
        "    \n",
        "    def increment_hp_lambda_by(self,increment):\n",
        "        new_value = float(K.get_value(self.hp_lambda)) +  increment\n",
        "        K.set_value(self.hp_lambda, new_value)\n",
        "\n",
        "    def get_hp_lambda(self):\n",
        "        return float(K.get_value(self.hp_lambda))\n",
        "\n",
        "\n",
        "class BaselineArch():\n",
        "  \"\"\"Superclass for multihead training.\"\"\"\n",
        "\n",
        "  def __init__(self, main=\"y\", aux=None, dtype=tf.float32, pos=None):\n",
        "    \"\"\"Initializer.\n",
        "\n",
        "    Args:\n",
        "      main: name of the variable for the main task.\n",
        "      aux: name of the variable for the auxiliary task.\n",
        "      dtype: desired dtype (e.g. tf.float32).\n",
        "      pos: ConfigDict that specifies the index of x, y, a in the data tuple.\n",
        "        Default: data is of the form (x, y, a).\n",
        "    \"\"\"\n",
        "    self.model = None\n",
        "    self.inputs = \"x\"\n",
        "    self.main = main\n",
        "    self.aux = aux\n",
        "    self.dtype = dtype\n",
        "    if pos is None:\n",
        "      pos = mlc.ConfigDict()\n",
        "      pos.x, pos.y, pos.a = 0, 1, 2\n",
        "    self.pos = pos\n",
        "\n",
        "  def get_input(self, *batch):\n",
        "    \"\"\"Fetch model input from the batch.\"\"\"\n",
        "    # first input\n",
        "    stack = tf.cast(batch[self.pos[self.inputs[0]]], self.dtype)\n",
        "    # fetch remaining ones\n",
        "    for c in self.inputs[1:]:\n",
        "      stack = tf.concat([stack, tf.cast(batch[self.pos[c]], self.dtype)],\n",
        "                        axis=1)\n",
        "    return stack\n",
        "\n",
        "  def get_output(self, *batch):\n",
        "    \"\"\"Fetch outputs from the batch.\"\"\"\n",
        "    if self.aux:\n",
        "      return (tf.cast(batch[self.pos[self.main]],self.dtype),\n",
        "              tf.cast(batch[self.pos[self.aux]], self.dtype))\n",
        "    else:\n",
        "      return (tf.cast(batch[self.pos[self.main]],self.dtype))\n",
        "\n",
        "  def split_batch(self, *batch):\n",
        "    \"\"\"Split batch into input and output.\"\"\"\n",
        "    return self.get_input(*batch), self.get_output(*batch)\n",
        "\n",
        "  def fit(self, data: tf.data.Dataset, **kwargs):\n",
        "    \"\"\"Fit model on data.\"\"\"\n",
        "    ds = data.map(self.split_batch)\n",
        "    self.model.fit(ds, **kwargs)\n",
        "\n",
        "  def predict(self, model_input, **kwargs):\n",
        "    \"\"\"Predict target Y given the model input. See also: predict_mult().\"\"\"\n",
        "    y_pred = self.model.predict(model_input, **kwargs)\n",
        "    return y_pred\n",
        "\n",
        "  def predict_mult(self, data: tf.data.Dataset, num_batches: int, **kwargs):\n",
        "    \"\"\"Predict target Y from the TF dataset directly. See also: predict().\"\"\"\n",
        "    y_true = []\n",
        "    y_pred = []\n",
        "    ds_iter = iter(data)\n",
        "    for _ in range(num_batches):\n",
        "      batch = next(ds_iter)\n",
        "      model_input, y = self.split_batch(*batch)\n",
        "      y_true.extend(y)\n",
        "      y_pred.extend(self.predict(model_input, **kwargs))\n",
        "    return np.array(y_true), np.array(y_pred)\n",
        "\n",
        "  def score(self, data: tf.data.Dataset, num_batches: int, \n",
        "            metric: tf.keras.metrics.Metric , **kwargs):\n",
        "    \"\"\"Evaluate model on data.\n",
        "\n",
        "    Args:\n",
        "      data: TF dataset.\n",
        "      num_batches: number of batches fetched from the dataset.\n",
        "      metric: which metric to evaluate, passed as a class rather than an instance.\n",
        "      **kwargs: arguments passed to predict() method.\n",
        "\n",
        "    Returns:\n",
        "      score: evaluation score.\n",
        "    \"\"\"\n",
        "    y_true, y_pred = self.predict_mult(data, num_batches, **kwargs)\n",
        "    return metric()(y_true, y_pred).numpy()\n",
        "\n",
        "\n",
        "class MultiHead(BaselineArch):\n",
        "  \"\"\"Multihead training.\"\"\"\n",
        "\n",
        "  def __init__(self, cfg, main, aux, dtype=tf.float32, pos=None): \n",
        "    \"\"\"Initializer.\n",
        "\n",
        "    Args:\n",
        "      cfg: A config that describes the MLP architecture.\n",
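        "      main: variable for the main task (the initializer currently sets it to \"y\").\n",
        "      aux: variable for the auxiliary task (the initializer currently sets it to \"a\").\n",
        "      dtype: desired dtype (e.g. tf.float32) for casting data.\n",
        "      pos: ConfigDict with the index of x, y, a in the data tuple.\n",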
        "    \"\"\"\n",
        "    super(MultiHead, self).__init__(main, aux, dtype, pos)\n",
        "    self.main = \"y\"\n",
        "    self.aux = \"a\"\n",
        "    self.cfg = cfg\n",
        "    # build architecture\n",
        "    self.model, self.feat_extract = self.build()\n",
        "\n",
        "  def build(self):\n",
        "    \"\"\"Build model.\"\"\"\n",
        "    cfg = self.cfg\n",
        "    input_shape = cfg.model.x_dim\n",
        "\n",
        "    # set config params to defaults if missing\n",
        "    use_bias = cfg.model.get(\"use_bias\", True)\n",
        "    activation = cfg.model.get(\"activation\", \"relu\")\n",
        "    output_activation = cfg.model.get(\"output_activation\", \"sigmoid\")\n",
        "\n",
        "    model_input = tf.keras.Input(shape=input_shape)\n",
        "    flatten_input = tf.keras.layers.Flatten()(model_input)\n",
        "    if cfg.model.depth:\n",
        "      x = tf.keras.layers.Dense(cfg.model.width, use_bias=use_bias,\n",
        "                                activation=activation,\n",
        "                                kernel_regularizer=cfg.model.regularizer)(flatten_input)\n",
        "      for _ in range(cfg.model.depth - 1):\n",
        "        x = tf.keras.layers.Dense(cfg.model.width, use_bias=use_bias,\n",
        "                                  activation=activation,\n",
        "                                  kernel_regularizer=cfg.model.regularizer)(x)\n",
        "    else:\n",
        "      x = flatten_input\n",
        "    feature_extractor = tf.keras.models.Model(inputs=flatten_input,\n",
        "                                              outputs=x)\n",
        "    # output layer - a single linear layer\n",
        "    y = tf.keras.layers.Dense(cfg.model.output_dim,\n",
        "                              use_bias=use_bias,\n",
        "                              name=\"output\",\n",
        "                              activation=output_activation,\n",
        "                              kernel_regularizer=cfg.model.regularizer)(x)\n",
        "    # attribute layer - an extra dense layer is required for gradients to flow back\n",
        "    attr_activation = cfg.model.get(\"attr_activation\", \"sigmoid\")\n",
        "    # identity in the forward pass; scales gradients by attr_grad_updates in the backward pass\n",
        "    input_branch_a = GradientReversal(hp_lambda=cfg.model.attr_grad_updates)(x)\n",
        "    a_branch = tf.keras.layers.Dense(cfg.model.branch_dim,\n",
        "                                     use_bias=use_bias,\n",
        "                                     name=\"attr_branch\",\n",
        "                                     activation=activation,\n",
        "                                     kernel_regularizer=cfg.model.regularizer)(input_branch_a)\n",
        "    a = tf.keras.layers.Dense(cfg.model.attr_dim,\n",
        "                              use_bias=use_bias,\n",
        "                              name=\"attribute\",\n",
        "                              activation=attr_activation,\n",
        "                              kernel_regularizer=cfg.model.regularizer)(a_branch)\n",
        "\n",
        "    # choose optimizer\n",
        "    if cfg.opt.name == \"sgd\":\n",
        "      opt = tf.keras.optimizers.SGD(learning_rate=cfg.opt.learning_rate,\n",
        "                                    momentum=cfg.opt.get(\"momentum\", 0.9))\n",
        "    elif cfg.opt.name == \"adam\":\n",
        "      opt = tf.keras.optimizers.Adam(learning_rate=cfg.opt.learning_rate)\n",
        "    else:\n",
        "      raise ValueError(\"Unrecognized optimizer type. \"\n",
        "                       \"Please select either 'sgd' or 'adam'.\")\n",
        "\n",
        "    # define losses\n",
        "    losses = {\n",
        "        \"output\": cfg.model.get(\"output_loss\", \"binary_crossentropy\"),\n",
        "        \"attribute\": cfg.model.get(\"attribute_loss\", \"binary_crossentropy\")\n",
        "    }\n",
        "    loss_weights = {\"output\": 1.0,\n",
        "                    \"attribute\": cfg.get(\"attr_loss_weight\", 1.0)}\n",
        "    metrics = {\"output\": tf.keras.metrics.AUC(),\n",
        "               \"attribute\": tf.keras.metrics.AUC()}\n",
        "\n",
        "    # build model\n",
        "    model = tf.keras.models.Model(inputs=model_input, outputs=[y,a])\n",
        "    model.build(input_shape)\n",
        "    # model.compile(optimizer=opt, loss=tf.keras.losses.BinaryCrossentropy(),\n",
        "    #               metrics=tf.keras.metrics.BinaryAccuracy())\n",
        "    model.compile(optimizer=opt, loss=losses, loss_weights=loss_weights,\n",
        "                  metrics=metrics)\n",
        "    return model, feature_extractor\n",
        "\n",
        "  def predict_mult(self, data: tf.data.Dataset, num_batches: int, **kwargs):\n",
        "    \"\"\"Predict from the TF dataset directly. See also: predict().\"\"\"\n",
        "    # infer dimensions\n",
        "    pos = self.pos\n",
        "    batch = next(iter(data))\n",
        "    y_dim = batch[pos.y].shape[1]\n",
        "    a_dim = batch[pos.a].shape[1]\n",
        "\n",
        "    # begin\n",
        "    data_iter = iter(data)\n",
        "    a_true_all = np.array([]).reshape((0, a_dim))\n",
        "    a_pred_all = np.array([]).reshape((0, a_dim))\n",
        "    y_true_all = np.array([]).reshape((0, y_dim))\n",
        "    y_pred_all = np.array([]).reshape((0, y_dim))\n",
        "\n",
        "    for _ in range(num_batches):\n",
        "      batch = next(data_iter)\n",
        "      x, y_true, a_true = batch[pos.x], batch[pos.y], batch[pos.a]\n",
        "      y_pred, a_pred = self.predict(x, **kwargs)\n",
        "      a_true_all = np.append(a_true_all, a_true, axis=0)\n",
        "      a_pred_all = np.append(a_pred_all, a_pred, axis=0)\n",
        "      y_true_all = np.append(y_true_all, y_true, axis=0)\n",
        "      y_pred_all = np.append(y_pred_all, y_pred, axis=0)\n",
        "\n",
        "    return (y_true_all, a_true_all), (y_pred_all, a_pred_all)\n",
        "\n",
        "  def score(self, data: tf.data.Dataset, num_batches: int, \n",
        "            metric: tf.keras.metrics.Metric, **kwargs):\n",
        "    \"\"\"Evaluate model on data.\n",
        "\n",
        "    Args:\n",
        "      data: TF dataset.\n",
        "      num_batches: number of batches fetched from the dataset.\n",
        "      metric: which metric to evaluate (should not be instantiated).\n",
        "      **kwargs: arguments passed to predict() method.\n",
        "\n",
        "    Returns:\n",
        "      score: evaluation score.\n",
        "    \"\"\"\n",
        "    out_true, out_pred = self.predict_mult(data, num_batches, **kwargs)\n",
        "    scores = []\n",
        "    for head in range(len(out_true)):\n",
        "      score = metric()(out_true[head], out_pred[head])\n",
        "      scores.append(score.numpy())\n",
        "    return scores\n",
        "\n",
        "\n",
        "# Can be used as a single task model fully trained or from a pre-trained\n",
        "# feature extractor\n",
        "\n",
        "class SingleHead(BaselineArch):\n",
        "  \"\"\"Singlehead training.\"\"\"\n",
        "\n",
        "  def __init__(self, cfg, main, dtype=tf.float32, pos=None, feat_extract=None): \n",
        "    \"\"\"Initializer.\n",
        "\n",
        "    Args:\n",
        "      cfg: A config that describes the MLP architecture.\n",
        "      main: variable for the main task\n",
        "      aux: variable for the auxialiary task\n",
        "      dtype: desired dtype (e.g. tf.float32) for casting data.\n",
        "    \"\"\"\n",
        "    super(SingleHead, self).__init__(main, None, dtype, pos)\n",
        "    self.main = \"a\"\n",
        "    self.cfg = cfg\n",
        "    # build architecture\n",
        "    self.model = self.build(feat_extract)\n",
        "\n",
        "  def build(self, feat_extract=None):\n",
        "    \"\"\"Build model.\"\"\"\n",
        "    cfg = self.cfg\n",
        "    input_shape = cfg.model.x_dim\n",
        "\n",
        "    # set config params to defaults if missing\n",
        "    use_bias = cfg.model.get(\"use_bias\", True)\n",
        "    activation = cfg.model.get(\"activation\", \"relu\")\n",
        "    output_activation = cfg.model.get(\"output_activation\", \"sigmoid\")\n",
        "\n",
        "    model_input = tf.keras.Input(shape=input_shape)\n",
        "    flatten_input = tf.keras.layers.Flatten()(model_input)\n",
        "    if not feat_extract:\n",
        "      if cfg.model.depth:\n",
        "        x = tf.keras.layers.Dense(cfg.model.width, use_bias=use_bias,\n",
        "                                  activation=activation,\n",
        "                                  kernel_regularizer=cfg.model.regularizer)(flatten_input)\n",
        "        for _ in range(cfg.model.depth - 1):\n",
        "          x = tf.keras.layers.Dense(cfg.model.width, use_bias=use_bias,\n",
        "                                    activation=activation,\n",
        "                                    kernel_regularizer=cfg.model.regularizer)(x)\n",
        "      else:\n",
        "        x = flatten_input\n",
        "      feature_extractor = x\n",
        "    else:\n",
        "      feat_extract.trainable = False\n",
        "      feature_extractor = feat_extract(flatten_input, training=False)\n",
        "  \n",
        "    # output layer\n",
        "    y = tf.keras.layers.Dense(cfg.model.output_dim,\n",
        "                              use_bias=cfg.model.use_bias,\n",
        "                              name=\"output\",\n",
        "                              activation=output_activation,\n",
        "                              kernel_regularizer=cfg.model.regularizer)(feature_extractor)  \n",
        "\n",
        "    # choose optimizer\n",
        "    if cfg.opt.name == \"sgd\":\n",
        "      opt = tf.keras.optimizers.SGD(learning_rate=cfg.opt.learning_rate,\n",
        "                                    momentum=cfg.opt.get(\"momentum\", 0.9))\n",
        "    elif cfg.opt.name == \"adam\":\n",
        "      opt = tf.keras.optimizers.Adam(learning_rate=cfg.opt.learning_rate)\n",
        "    else:\n",
        "      raise ValueError(\"Unrecognized optimizer type.\"\n",
        "                       \"Please select either 'sgd' or 'adam'.\")\n",
        "\n",
        "    # build model\n",
        "    model = tf.keras.models.Model(inputs=model_input, outputs=y)\n",
        "    model.build(input_shape)\n",
        "    model.compile(optimizer=opt,\n",
        "                  loss=cfg.model.get(\"output_loss\", \"binary_crossentropy\"),\n",
        "                  metrics=tf.keras.metrics.AUC())\n",
        "\n",
        "    return model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VPeJWiIajIOX"
      },
      "outputs": [],
      "source": [
        "#@markdown Model parameters \n",
        "cfg = mlc.ConfigDict()\n",
        "\n",
        "cfg.model = mlc.ConfigDict()\n",
        "cfg.model.width = 10  #@param architecture width\n",
        "cfg.model.depth = 3  #@param model depth\n",
        "cfg.model.use_bias = True  # whether we add biases to activations.\n",
        "cfg.model.activation = 'relu'\n",
        "cfg.model.x_dim = (n_pixels, n_pixels, n_channels)\n",
        "cfg.model.branch_dim = 2  # architecture width within each branch\n",
        "cfg.model.regularizer = None  # replace with e.g. 'l2' for weight decay\n",
        "\n",
        "# output head\n",
        "cfg.model.output_activation = 'sigmoid'\n",
        "cfg.model.output_dim = 1\n",
        "\n",
        "# attribute head\n",
        "cfg.model.attr_activation = 'sigmoid'\n",
        "cfg.model.attr_grad_updates = float(-0.05)\n",
        "cfg.model.attr_dim = 1\n",
        "# this is a tradeoff between the loss on A and the loss on target Y.\n",
        "# If it's zero, we ignore the attribute loss completely.\n",
        "cfg.attr_loss_weight = float(1.0)\n",
        "\n",
        "cfg.opt = mlc.ConfigDict()\n",
        "cfg.opt.name = 'adam'\n",
        "cfg.opt.learning_rate = 0.001\n",
        "cfg.opt.momentum = 0.9"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cnJowASOjra_"
      },
      "source": [
        "# Baseline task model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gd5vA7vwjvSu"
      },
      "outputs": [],
      "source": [
        "metric = tf.keras.metrics.AUC\n",
        "\n",
        "baseline = []\n",
        "for seed in [0,1,2,3,4]:\n",
        "  tf.random.set_seed(seed)\n",
        "  np.random.seed(seed)\n",
        "  enc = SingleHead(cfg, main=\"y\", dtype=tf.float32, feat_extract=None)\n",
        "  kwargs = {'epochs': 100, 'steps_per_epoch':20, 'verbose': False}\n",
        "  enc.fit(train_data, **kwargs)\n",
        "  kwargs = {'verbose': False}\n",
        "  num_batches = 30\n",
        "  sc = enc.score(valid_data, num_batches, metric, **kwargs)\n",
        "  baseline.append(sc)\n",
        "\n",
        "print(\"Baseline model: %1.2f +- %.2f\" % (np.mean(baseline),np.std(baseline)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FMewMxgikOmD"
      },
      "source": [
        "# Bounds on attribute encoding (i.e. LEB and UEB in paper)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1kWfJz8tkNfu"
      },
      "outputs": [],
      "source": [
        "#@title Upper bound\n",
        "metric = tf.keras.metrics.AUC\n",
        "\n",
        "upper_bound = []\n",
        "for seed in [0,1,2,3,4]:\n",
        "  tf.random.set_seed(seed)\n",
        "  np.random.seed(seed)\n",
        "  enc = SingleHead(cfg, main=\"a\", dtype=tf.float32, feat_extract=None)\n",
        "  kwargs = {'epochs': 100, 'steps_per_epoch':20, 'verbose': False}\n",
        "  enc.fit(train_data, **kwargs)\n",
        "  kwargs = {'verbose': False}\n",
        "  num_batches = 30\n",
        "  sc = enc.score(valid_data, num_batches, metric, **kwargs)\n",
        "  upper_bound.append(sc)\n",
        "\n",
        "print(\"Upper bound: %1.2f +- %.2f\" % (np.mean(upper_bound),np.std(upper_bound)))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zCavAI_Fki2W"
      },
      "outputs": [],
      "source": [
        "#@title Lower bound\n",
        "# Predict a color randomly selected from the train set (with replacement)\n",
        "\n",
        "n_train, a_dim = dataset[\"train\"][\"color_label\"].shape\n",
        "\n",
        "lower_bound = []\n",
        "for seed in [0,1,2,3,4]:\n",
        "  tf.random.set_seed(seed)\n",
        "  np.random.seed(seed)\n",
        "  data_iter = iter(train_data)\n",
        "  a_true_all = np.array([]).reshape((0, a_dim))\n",
        "  a_pred_all = np.array([]).reshape((0, a_dim))\n",
        "\n",
        "  for _ in range(num_batches):\n",
        "    batch = next(data_iter)\n",
        "    a_true = batch[pos.a]\n",
        "    a_pred = dataset[\"train\"][\"color_label\"][np.random.choice(n_train,a_true.shape[0]), ...]\n",
        "    preds = a_pred.reshape((-1,a_dim))\n",
        "    a_true_all = np.append(a_true_all, a_true, axis=0)\n",
        "    a_pred_all = np.append(a_pred_all, preds, axis=0)\n",
        "\n",
        "  sc = tf.keras.metrics.AUC()(a_true_all, a_pred_all).numpy()\n",
        "  lower_bound.append(sc)\n",
        "\n",
        "print(\"Lower bound: %1.2f +- %.2f\" % (np.mean(lower_bound),np.std(lower_bound)))"
      ]
    },
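    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Why the lower bound should sit near chance level: a short derivation in our own notation, assuming binary color labels drawn independently of the evaluated images. The drawn prediction $\\hat{a} \\in \\{0, 1\\}$ yields a single ROC point $(\\mathrm{FPR}, \\mathrm{TPR})$, and the piecewise-linear ROC curve through $(0,0)$, that point and $(1,1)$ has area\n",
        "\n",
        "$$\\mathrm{AUC} = \\frac{\\mathrm{FPR}\\,\\mathrm{TPR}}{2} + \\frac{(1-\\mathrm{FPR})(1+\\mathrm{TPR})}{2} = \\frac{1 + \\mathrm{TPR} - \\mathrm{FPR}}{2}.$$\n",
        "\n",
        "Because $\\hat{a}$ does not depend on the image, $\\mathbb{E}[\\mathrm{TPR}] = \\mathbb{E}[\\mathrm{FPR}] = P(\\hat{a} = 1)$, so the expected AUC is $1/2$; the measured bound should be close to 0.5, up to sampling noise."
      ]
    },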
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9eY_qcOPkyo5"
      },
      "source": [
        "# Example of Multi-Task model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IHFinRviksIU"
      },
      "outputs": [],
      "source": [
        "#@title Initialize from config\n",
        "cfg.model.attr_grad_updates = float(0.1) # if both are set to 0, this is a baseline model (i.e. the same as with no extra head)\n",
        "cfg.attr_loss_weight = float(0.75)\n",
        "clf = MultiHead(cfg, main=\"y\", aux=\"a\", dtype=tf.float32, pos=None)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LrriExy3k8cc"
      },
      "outputs": [],
      "source": [
        "#@markdown Explore the architecture\n",
        "clf.model.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rHyqFeJelCQM"
      },
      "outputs": [],
      "source": [
        "#@title Train the model\n",
        "kwargs = {'epochs': 100, 'steps_per_epoch':20, 'verbose': True}\n",
        "clf.fit(train_data, **kwargs)\n",
        "clf.trainable=False"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "n1jZXkiilipk"
      },
      "outputs": [],
      "source": [
        "#@title Define the decision threshold by maximizing the F1-score on validation data\n",
        "\n",
        "from sklearn.metrics import precision_recall_curve\n",
        "def f1_curve(truth, prediction_scores, e=1e-6):\n",
        "  precision, recall, thresholds = precision_recall_curve(truth, prediction_scores)\n",
        "  f1 = 2*recall*precision/(recall+precision+e)\n",
        "  return thresholds, f1[:-1]\n",
        "\n",
        "def threshold_at_max_f1_score(truth, prediction_scores):\n",
        "  thresholds, f1 = f1_curve(truth, prediction_scores)\n",
        "  peak_idx = np.argmax(f1)\n",
        "  return thresholds[peak_idx]"
      ]
    },
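    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, the threshold selection above written in our own notation. With $P(t)$ and $R(t)$ the precision and recall obtained when predicting $\\hat{Y} = 1$ for scores $\\geq t$, and $\\epsilon = 10^{-6}$ (the `e` argument of `f1_curve`),\n",
        "\n",
        "$$\\mathrm{F1}(t) = \\frac{2\\,P(t)\\,R(t)}{P(t) + R(t) + \\epsilon}, \\qquad t^\\star = \\arg\\max_{t} \\mathrm{F1}(t),$$\n",
        "\n",
        "where $t$ ranges over the candidate thresholds returned by `sklearn.metrics.precision_recall_curve`."
      ]
    },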
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CmlOsACMltYq"
      },
      "outputs": [],
      "source": [
        "kwargs = {'verbose': False}\n",
        "num_batches = 20\n",
        "metric = tf.keras.metrics.AUC\n",
        "scores_val = clf.score(valid_data, num_batches, metric, **kwargs)\n",
        "print(scores_val)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KjztJ34UlxIq"
      },
      "outputs": [],
      "source": [
        "kwargs = {'verbose': False}\n",
        "yt, yp = clf.predict_mult(valid_data, num_batches, **kwargs)\n",
        "threshold = threshold_at_max_f1_score(yt[0],yp[0])\n",
        "threshold"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "ZviOthNFlycK"
      },
      "outputs": [],
      "source": [
        "#@title Fairness metrics\n",
        "\n",
        "# As per the work of Alabdulmohsin et al., 2021\n",
        "\n",
        "def fairness_metrics(y_pred, y_true, sens_attr):\n",
        "  eps = 1e-5\n",
        "  groups = np.unique(sens_attr).tolist()\n",
        "\n",
        "  max_error = 0\n",
        "  min_error = 1\n",
        "\n",
        "  max_mean_y = 0\n",
        "  min_mean_y = 1\n",
        "\n",
        "  max_mean_y0 = 0  # conditioned on y = 0\n",
        "  min_mean_y0 = 1\n",
        "\n",
        "  max_mean_y1 = 0\n",
        "  min_mean_y1 = 1\n",
        "\n",
        "  for group in groups:\n",
        "    yt = y_true[sens_attr == group].astype('int32')\n",
        "    ypt = (y_pred[sens_attr == group]).astype('int32')\n",
        "    err = -np.mean(yt * np.log(ypt+eps) + (1-yt)*np.log(1-ypt+eps))\n",
        "    mean_y = np.mean(y_pred[sens_attr == group])\n",
        "    neg = np.logical_and(sens_attr == group, y_true == 0)\n",
        "    pos = np.logical_and(sens_attr == group, y_true == 1)\n",
        "    mean_y0 = np.mean(y_pred[neg])\n",
        "    mean_y1 = np.mean(y_pred[pos])\n",
        "\n",
        "    if err \u003e max_error:\n",
        "      max_error = err\n",
        "    if err \u003c min_error:\n",
        "      min_error = err\n",
        "\n",
        "    if mean_y \u003e max_mean_y:\n",
        "      max_mean_y = mean_y\n",
        "    if mean_y \u003c min_mean_y:\n",
        "      min_mean_y = mean_y\n",
        "\n",
        "    if mean_y0 \u003e max_mean_y0:\n",
        "      max_mean_y0 = mean_y0\n",
        "    if mean_y0 \u003c min_mean_y0:\n",
        "      min_mean_y0 = mean_y0\n",
        "\n",
        "    if mean_y1 \u003e max_mean_y1:\n",
        "      max_mean_y1 = mean_y1\n",
        "    if mean_y1 \u003c min_mean_y1:\n",
        "      min_mean_y1 = mean_y1\n",
        "  \n",
        "  eo = 0.5*(max_mean_y0 - min_mean_y0 + max_mean_y1 - min_mean_y1)\n",
        "  dp = max_mean_y - min_mean_y\n",
        "  err_parity = max_error - min_error\n",
        "\n",
        "  return eo, dp, err_parity"
      ]
    },
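    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The gaps returned by `fairness_metrics`, transcribed from the code above in our own notation (see Alabdulmohsin et al., 2021 for the original definitions). With $\\hat{Y} \\in \\{0, 1\\}$ the thresholded prediction, $A$ the sensitive attribute and $\\hat{\\mathbb{E}}$ an empirical mean over the test examples,\n",
        "\n",
        "$$\\mathrm{DP} = \\max_{a} \\hat{\\mathbb{E}}[\\hat{Y} \\mid A = a] - \\min_{a} \\hat{\\mathbb{E}}[\\hat{Y} \\mid A = a],$$\n",
        "\n",
        "$$\\mathrm{EO} = \\frac{1}{2} \\sum_{y \\in \\{0, 1\\}} \\Big( \\max_{a} \\hat{\\mathbb{E}}[\\hat{Y} \\mid A = a, Y = y] - \\min_{a} \\hat{\\mathbb{E}}[\\hat{Y} \\mid A = a, Y = y] \\Big),$$\n",
        "\n",
        "$$\\mathrm{EP} = \\max_{a} \\ell_a - \\min_{a} \\ell_a, \\qquad \\ell_a = -\\frac{1}{n_a} \\sum_{i:\\, A_i = a} \\big[ y_i \\log(\\hat{y}_i + \\epsilon) + (1 - y_i) \\log(1 - \\hat{y}_i + \\epsilon) \\big],$$\n",
        "\n",
        "where $n_a$ is the number of test examples in group $a$ and $\\epsilon = 10^{-5}$. Since $\\hat{y}_i$ is binary, $\\ell_a$ is approximately $-\\log \\epsilon \\approx 11.5$ times the error rate of group $a$."
      ]
    },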
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YxTljd9CmD6S"
      },
      "outputs": [],
      "source": [
        "#@title Evaluate the model on test data\n",
        "kwargs = {'verbose': False}\n",
        "num_batches = 20\n",
        "metric = tf.keras.metrics.AUC\n",
        "scores = clf.score(test_data, num_batches, metric, **kwargs)\n",
        "scores"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CYTlaLromSji"
      },
      "outputs": [],
      "source": [
        "yt, yp = clf.predict_mult(test_data, num_batches, **kwargs)\n",
        "eo, dp, ep = fairness_metrics(yp[0]\u003e=threshold,\n",
        "                              yt[0],yt[1].astype('int32'))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rr9GtWwjmjHx"
      },
      "outputs": [],
      "source": [
        "#@title Evaluate model on counterfactual test data\n",
        "kwargs = {'verbose': False}\n",
        "num_batches = 20\n",
        "metric = tf.keras.metrics.AUC\n",
        "\n",
        "scores_flipped = clf.score(test_data_flipped, num_batches,metric, **kwargs)\n",
        "scores_flipped"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "r4YSRAk6mrow"
      },
      "outputs": [],
      "source": [
        "#@title Evaluate attribute encoding based on a frozen feature extractor from the multi-task model\n",
        "enc = SingleHead(cfg, main=\"a\", dtype=tf.float32, feat_extract=clf.feat_extract)\n",
        "kwargs = {'epochs': 30, 'steps_per_epoch':20, 'verbose': False}\n",
        "enc.fit(train_data, **kwargs)\n",
        "kwargs = {'verbose': False}\n",
        "num_batches = 20\n",
        "metric = tf.keras.metrics.AUC\n",
        "sc = enc.score(test_data, num_batches, metric, **kwargs)\n",
        "sc"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UczpCETAmyOu"
      },
      "source": [
        "# Piecing the elements together: ShorT"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jGsb0Tpsm37w"
      },
      "outputs": [],
      "source": [
        "#@title ShorT method\n",
        "def shortcut_testing(cfg, train_data, test_data, counterfact_data, val_data,\n",
        "                     range_grads, seeds = [0,1,2,3,4], num_epochs_train = 30,\n",
        "                     num_batch_test = 5):\n",
        "  \"\"\"Shortcut testing.\n",
        "\n",
        "  Requires:\n",
        "  - cfg: a config dictionary of model specifications\n",
        "  - train_data: tf dataset of train data respecting the positions for x,y,a\n",
        "  - test_data: tf dataset of test data\n",
        "  - counterfact_data: tf dataset of counterfactual test images\n",
        "  - val_data: validation data to compute decision threshold\n",
        "  - range_grads: np array of gradient scalings\n",
        "  - seeds: list of seed numbers to fix, per gradient scaling\n",
        "  - num_epoch_train: int, number of epochs for training\n",
        "  - num_batch_test: number of batches to test the model on\n",
        "\n",
        "  Outputs:\n",
        "  scores_y: np array of size (# gradient scalings, # seeds) of test performance on Y\n",
        "  encoding_a: np array of model's attribute encoding\n",
        "  equ_odds: np array of equalized odds on test data\n",
        "  dem_par: similar, demographic parity\n",
        "  err_par: error parity\n",
        "  count_fair: counterfactual fairness (computed locally then averaged)\n",
        "  scores_c: global performance on Y on counterfactual data\n",
        "  models: tf.Keras.model instances\n",
        "  \"\"\" \n",
        "\n",
        "\n",
        "  scores_y = np.zeros((len(range_grads), len(seeds)))\n",
        "  scores_c = np.zeros((len(range_grads), len(seeds)))\n",
        "  encoding_a = np.zeros((len(range_grads), len(seeds)))\n",
        "  equ_odds = np.zeros((len(range_grads), len(seeds)))\n",
        "  dem_par = np.zeros((len(range_grads), len(seeds)))\n",
        "  err_par = np.zeros((len(range_grads), len(seeds)))\n",
        "  count_fair = np.zeros((len(range_grads), len(seeds)))\n",
        "  models = []\n",
        "  metric = tf.keras.metrics.AUC\n",
        "  loss_weight = cfg.attr_loss_weight\n",
        "  for g, grad_scaling in enumerate(range_grads):\n",
        "    cfg.model.attr_grad_updates = float(grad_scaling)\n",
        "\n",
        "    if grad_scaling == 0:  ## baseline model with no other head\n",
        "      cfg.attr_loss_weight = float(0.0)\n",
        "    else:\n",
        "      cfg.attr_loss_weight = float(loss_weight)\n",
        "    \n",
        "    for s, seed in enumerate(seeds):\n",
        "      tf.random.set_seed(seed)\n",
        "      np.random.seed(seed)\n",
        "      # instantiate\n",
        "      clf = MultiHead(cfg, main=\"y\", aux=\"a\", dtype=tf.float32, pos=None)\n",
        "      # train the multi-head model\n",
        "      kwargs = {'epochs': num_epochs_train, 'steps_per_epoch':20, 'verbose': False}\n",
        "      clf.fit(train_data, **kwargs)\n",
        "      clf.trainable = False\n",
        "\n",
        "      # estimate attribute encoding by freezing the weights of the feature extractor\n",
        "      enc = SingleHead(cfg, main=\"a\", dtype=tf.float32, feat_extract=clf.feat_extract)\n",
        "      kwargs = {'epochs': num_epochs_train, 'steps_per_epoch':20, 'verbose': False}\n",
        "      enc.fit(train_data, **kwargs)\n",
        "      kwargs = {'verbose': False}\n",
        "      encoding_a[g,s] = enc.score(test_data, num_batch_test,\n",
        "                                  metric, **kwargs)\n",
        "\n",
        "      # Model performance on test\n",
        "      scores = clf.score(test_data, num_batch_test, metric, **kwargs)\n",
        "      scores_y[g,s] = scores[0]\n",
        "\n",
        "      # from validation data, obtain threshold for decision making\n",
        "      out_true, out_pred = clf.predict_mult(val_data,\n",
        "                                            num_batches=num_batch_test,\n",
        "                                            **kwargs)\n",
        "      threshold = threshold_at_max_f1_score(out_true[0],out_pred[0])\n",
        "\n",
        "      # Model statistical fairness metric\n",
        "      out_true, out_pred = clf.predict_mult(test_data, num_batches=num_batch_test, **kwargs)\n",
        "      eo, dp, ep = fairness_metrics(out_pred[0]\u003e=threshold,\n",
        "                                    out_true[0],out_true[1].astype('int32'))\n",
        "      equ_odds[g,s] = eo\n",
        "      dem_par[g,s] = dp\n",
        "      err_par[g,s] = ep\n",
        "      # global shortcutting\n",
        "      scores = clf.score(counterfact_data, num_batch_test, metric, **kwargs)\n",
        "      scores_c[g,s] = scores[0]\n",
        "      # counterfactual fairness metric\n",
        "      cf_true, cf_pred = clf.predict_mult(counterfact_data, num_batches=num_batch_test,\n",
        "                                          **kwargs)\n",
        "      count_fair[g,s] = np.mean(np.abs(\n",
        "          (out_pred[0]\u003e=threshold).astype(float) - (cf_pred[0]\u003e=threshold).astype(float)))\n",
        "      models.append(clf)\n",
        "      del clf\n",
        "  return scores_y, encoding_a, equ_odds, dem_par, err_par, count_fair, scores_c, models"
      ]
    },
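    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, the counterfactual fairness score computed in `shortcut_testing`, written in our own notation. With $f$ the trained model's score for $Y$, $t^\\star$ the F1-maximizing threshold from the validation data, and $\\tilde{x}_i$ the counterfactual (color-flipped) version of test image $x_i$,\n",
        "\n",
        "$$\\mathrm{CF} = \\frac{1}{n} \\sum_{i=1}^{n} \\Big| \\mathbb{1}[f(x_i) \\geq t^\\star] - \\mathbb{1}[f(\\tilde{x}_i) \\geq t^\\star] \\Big|,$$\n",
        "\n",
        "i.e. the fraction of evaluated test decisions that change when only the attribute is flipped (0 means that none change). Pairing $x_i$ with $\\tilde{x}_i$ assumes that `test_data` and the counterfactual dataset yield the same examples in the same order."
      ]
    },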
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cQy-hQvqorsx"
      },
      "outputs": [],
      "source": [
        "#@title Run ShorT\n",
        "\n",
        "# Ensure that these parameters cover the attribute encoding from LEB to UEB\n",
        "range_grads = [-0.09, -0.07, -0.05, -0.03, -0.02, -0.01, -0.005,0.0, 0.005, 0.01, 0.02, 0.03, 0.05,0.07, 0.09]\n",
        "\n",
        "# This should take some time to run\n",
        "acc_y, acc_a, eo, dp, ep, cf, counter_y, models = shortcut_testing(cfg, train_data,\n",
        "                                                      test_data, valid_data,\n",
        "                                                      test_data_flipped,\n",
        "                                                      range_grads,\n",
        "                                                      num_epochs_train=100,\n",
        "                                                      num_batch_test=20)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Aiw0SdbbpKw8"
      },
      "outputs": [],
      "source": [
        "#@markdown Derive the correlation\n",
        "\n",
        "filt = acc_y \u003e= 0.8  #@param User-defined performance threshold. The goal is to discard trivial models\n",
        "corr, p = scipy.stats.spearmanr(acc_a[filt],cf[filt])\n",
        "print(corr)\n",
        "print(p)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "WK3lk32PpfoT"
      },
      "outputs": [],
      "source": [
        "#@title Plot utils\n",
        "\n",
        "SMALL_SIZE = 10\n",
        "MEDIUM_SIZE = 14\n",
        "BIGGER_SIZE = 18\n",
        "\n",
        "plt.rc('font', size=SMALL_SIZE)          # controls default text sizes\n",
        "plt.rc('axes', titlesize=BIGGER_SIZE)     # fontsize of the axes title\n",
        "plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels\n",
        "plt.rc('xtick', labelsize=MEDIUM_SIZE)    # fontsize of the tick labels\n",
        "plt.rc('ytick', labelsize=MEDIUM_SIZE)    # fontsize of the tick labels\n",
        "plt.rc('legend', fontsize=MEDIUM_SIZE)    # legend fontsize\n",
        "plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
        "\n",
        "palette = sns.color_palette('tab20b')\n",
        "\n",
        "def plot_scale_gradients_encoding(scale_gradients, encoding_a, \n",
        "                                  upper_bound, lower_bound):\n",
        "  \"\"\"Plots the intervention compared to attribute encoding.\n",
        "\n",
        "  Inputs:\n",
        "  scale_gradients: numpy array of scale gradients used\n",
        "  encoding_a: numpy array of attribute encoding values, as measure by transfer learning\n",
        "  upper_bound: list or numpy array of maximum attribute encoding\n",
        "  lower_bound: list or numpy array of minimum attribute encoding\n",
        "  \"\"\"\n",
        "  \n",
        "  sg_mean = np.mean(encoding_a, axis=1)\n",
        "  sg_std = np.std(encoding_a, axis=1)\n",
        "\n",
        "  fig = plt.figure(figsize=(5,4))\n",
        "  ax = fig.add_axes([0,0,1,1])\n",
        "  ax.errorbar(scale_gradients, sg_mean, \n",
        "                yerr=sg_std, \n",
        "                fmt='x', \n",
        "                color='tab:blue',\n",
        "                ecolor='tab:blue')\n",
        "  plt.hlines(np.mean(upper_bound),np.min(scale_gradients), np.max(scale_gradients),\n",
        "            colors=[0.4,0.4,0.4],linestyles='dashed')\n",
        "  plt.hlines(np.mean(lower_bound),np.min(scale_gradients), np.max(scale_gradients),\n",
        "            colors=[0.4,0.4,0.4],linestyles='dashed')\n",
        "  ax.set_xlabel('Scale Gradient')\n",
        "  ax.set_ylabel('Attribute encoding')\n",
        "  plt.show\n",
        "\n",
        "def plot_fairness_encoding(encoding_m, fair_m, perf_m, perf_thresh = 0):\n",
        "  \"\"\"Plots fairness results vs attribute encoding.\n",
        "  \n",
        "  Inputs:\n",
        "  encoding_m: encoding metric result, as numpy array\n",
        "  fair_m: fairness metric result, as numpy array\n",
        "  perf_m: model performance on the output label, as numpy array\n",
        "  perf_thresh: what minimum performance to consider\n",
        "  \"\"\"\n",
        "\n",
        "  filt = perf_m \u003c= perf_thresh\n",
        "\n",
        "  fig = plt.figure(figsize=(5,4))\n",
        "  plt.scatter(encoding_m, fair_m,color=[0.6,0,0.2],alpha=0.5)\n",
        "  plt.scatter(encoding_m[filt], fair_m[filt],color=[0.2,0.2,0.2])\n",
        "  plt.xlabel('Attribute Accuracy')\n",
        "  plt.ylabel('Fairness')\n",
        "  plt.show()\n",
        "\n",
        "def point_z_order(c, midpoint):\n",
        "  deviation = np.zeros_like(c)\n",
        "  for i in range(c.shape[0]):\n",
        "    if c[i] \u003e midpoint:\n",
        "      deviation[i] = (c[i]-midpoint) / (np.max(c)-midpoint)\n",
        "    else:\n",
        "      deviation[i] = (midpoint-c[i]) / (midpoint-np.min(c))\n",
        "  return np.argsort(deviation)\n",
        "\n",
        "def performance_fairness_age_frontier_plot(encoding_m, fair_m, perf_m,\n",
        "                                           scale_gradients, cmap='PRGn'):\n",
        "  \n",
        "  class MidpointNormalize(mpl.colors.Normalize):\n",
        "    \"\"\"\n",
        "    class to help renormalize the color scale\n",
        "    \"\"\"\n",
        "    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):\n",
        "        self.midpoint = midpoint\n",
        "        mpl.colors.Normalize.__init__(self, vmin, vmax, clip)\n",
        "\n",
        "    def __call__(self, value, clip=None):\n",
        "        x, y = [self.vmin, self.midpoint, self.vmax], [0, 0.5, 1]\n",
        "        return np.ma.masked_array(np.interp(value, x, y))\n",
        "\n",
        "  baseline_models = np.argwhere(np.array(scale_gradients)==0.0)[0]\n",
        "  if baseline_models.shape[0] == 0:\n",
        "    baseline_models = int(len(scale_gradients)/2)\n",
        "  midpoint = encoding_m[baseline_models,:].mean()\n",
        "  baseline_model_perf = perf_m[baseline_models,:].mean()\n",
        "  baseline_model_fair = fair_m[baseline_models,:].mean()\n",
        "\n",
        "  print(f'Baseline model Attribute encoding: {midpoint:.2f}')\n",
        "  print(f'Baseline model Performance: {baseline_model_perf:.4f}')\n",
        "  print(f'Baseline model Fairness: {baseline_model_fair:.4f}')\n",
        "\n",
        "  norm =  MidpointNormalize(midpoint = midpoint)\n",
        "\n",
        "  attr = encoding_m.flatten()\n",
        "  z_order = point_z_order(attr, midpoint)\n",
        "\n",
        "  fair =fair_m.flatten()\n",
        "  perf = perf_m.flatten()\n",
        "  fig = plt.figure(figsize=(5,4))\n",
        "  ax = fig.add_axes([0,0,1,1])\n",
        "  plt.scatter(fair[z_order], perf[z_order], s=30, c=attr[z_order], cmap=cmap, norm=norm)\n",
        "  # overplot the baseline models in red\n",
        "  plt.scatter(fair_m[baseline_models,:], perf_m[baseline_models,:], s=30,\n",
        "                   color=(0.8, 0.2, 0.2))\n",
        "  plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap),label='Attribute Encoding')\n",
        "  plt.ylabel('Performance')\n",
        "  plt.xlabel('Fairness')\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0HJjIltJpFbr"
      },
      "outputs": [],
      "source": [
        "#@title Plot scale gradient vs age encoding\n",
        "plot_scale_gradients_encoding(range_grads, acc_a, upper_bound, lower_bound)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sHORVCqBpnFT"
      },
      "outputs": [],
      "source": [
        "#@title Plot Fairness w.r.t. attribute encoding\n",
        "suffix = \"{}_EO\".format(cfg_data.corr_a_y1)\n",
        "plot_fairness_encoding(acc_a, cf, acc_y, perf_thresh = 0.8)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gfx8TY70pt_D"
      },
      "outputs": [],
      "source": [
        "#@title Pareto plot of fairness-performance colored by attribute encoding\n",
        "performance_fairness_age_frontier_plot(acc_a,eo,acc_y,range_grads)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QZIM0AucYMhX"
      },
      "outputs": [],
      "source": [
        "#@title Selecting a model and dropping the attribute head for deployment\n",
        "\n",
        "fairness_threshold = 0.09 #@param\n",
        "gradient_parameters = np.array(range_grads).reshape(-1,1).repeat(5,axis=1)\n",
        "\n",
        "possible_models = acc_y[eo\u003c=fairness_threshold]\n",
        "possible_gradients = gradient_parameters[eo\u003c=fairness_threshold]\n",
        "filt = eo\u003c=fairness_threshold\n",
        "\n",
        "ind = np.argmax(possible_models)\n",
        "selected_gradient = float(possible_gradients[ind])\n",
        "selected_model = models[np.argwhere(filt.flatten()).tolist()[ind][0]]\n",
        "print(\"Selected gradient parameter: {}\".format(selected_gradient))\n",
        "print(\"Selected model performance: {}\".format(possible_models[ind]))\n",
        "print(\"Selected model fairness: {}\".format(eo[eo\u003c=fairness_threshold][ind]))"
      ]
    },
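    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The selection cell above keeps only the models whose fairness gap `eo` is at most `fairness_threshold`, then picks the one with the highest main-task performance `acc_y`. The sketch below restates that constrained selection on made-up numbers. It assumes, as the `repeat(5, axis=1)` call suggests, that results form a grid with one row per gradient value and one column per trained model, and that `models` is the matching flat list in row-major order. `select_under_constraint` is a hypothetical helper, and it returns `None` instead of failing when no model meets the threshold."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def select_under_constraint(perf, gap, threshold):\n",
        "  # Hypothetical helper: flat (row-major) index of the best-performing model\n",
        "  # whose fairness gap is <= threshold, or None if no model qualifies.\n",
        "  perf = np.asarray(perf, dtype=float)\n",
        "  gap = np.asarray(gap, dtype=float)\n",
        "  feasible = (gap <= threshold).ravel()\n",
        "  if not feasible.any():\n",
        "    return None\n",
        "  candidates = np.flatnonzero(feasible)  # flat indices of the feasible models\n",
        "  return int(candidates[np.argmax(perf.ravel()[candidates])])\n",
        "\n",
        "# Toy grid with made-up numbers: 2 gradient values (rows) x 3 models per value (columns).\n",
        "toy_grads = np.array([0.0, 1.0]).reshape(-1, 1).repeat(3, axis=1)\n",
        "toy_perf = np.array([[0.95, 0.94, 0.96], [0.90, 0.91, 0.89]])\n",
        "toy_gap = np.array([[0.15, 0.12, 0.20], [0.05, 0.08, 0.04]])\n",
        "toy_best = select_under_constraint(toy_perf, toy_gap, threshold=0.09)\n",
        "print(toy_best, toy_grads.ravel()[toy_best], toy_perf.ravel()[toy_best])"
      ]
    },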
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9LRDs_1eewqb"
      },
      "outputs": [],
      "source": [
        "selected_model.model.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_NjgGzo6aVOD"
      },
      "outputs": [],
      "source": [
        "# Build the deployment model from the selected model, keeping every layer except those of the attribute branch.\n",
        "# Note: tf.keras.Sequential assumes the remaining (label) path is a linear stack of layers.\n",
        "remove = ['attribute', 'attr_branch', 'gradient']  # layer names (partial matches are fine), as defined in the MultiHead class\n",
        "\n",
        "final_model = tf.keras.Sequential()\n",
        "for layer in selected_model.model.layers:\n",
        "  match = [to_pop in layer.name for to_pop in remove]\n",
        "  if not any(match):\n",
        "    final_model.add(layer)\n",
        "    layer.trainable = False  # freeze the reused layers\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "waoOroZOhiF0"
      },
      "outputs": [],
      "source": [
        "final_model.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mT0Ti50ChoIj"
      },
      "outputs": [],
      "source": [
        "x = test_dataset['train']['image']\n",
        "y_pred = final_model.predict_on_batch(x)"
      ]
    },
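    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The next cell evaluates the deployed model with `sklearn.metrics.roc_auc_score`. The sketch below illustrates what that number means by computing ROC AUC from its ranking interpretation: the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one, with ties counted as one half. `pairwise_auc` and the toy data are hypothetical. The helper builds every positive-negative pair, so it is only suitable for small sanity checks."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def pairwise_auc(labels, scores):\n",
        "  # Hypothetical helper: P(score of a random positive > score of a random negative),\n",
        "  # ties counted as 1/2. Uses O(n_pos * n_neg) memory, so only for small checks.\n",
        "  labels = np.asarray(labels).ravel()\n",
        "  scores = np.asarray(scores, dtype=float).ravel()\n",
        "  pos = scores[labels == 1]\n",
        "  neg = scores[labels == 0]\n",
        "  diff = pos[:, None] - neg[None, :]  # all positive-minus-negative score differences\n",
        "  return float((diff > 0).mean() + 0.5 * (diff == 0).mean())\n",
        "\n",
        "# Toy example with made-up labels and scores.\n",
        "toy_labels = np.array([0, 0, 1, 1])\n",
        "toy_scores = np.array([0.1, 0.4, 0.35, 0.8])\n",
        "print(pairwise_auc(toy_labels, toy_scores))"
      ]
    },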
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PlZ_04T9imOp"
      },
      "outputs": [],
      "source": [
        "sklearn.metrics.roc_auc_score(test_dataset['train']['label'], y_pred)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "xujPFv_tgv_R",
        "cnJowASOjra_",
        "FMewMxgikOmD",
        "9eY_qcOPkyo5",
        "UczpCETAmyOu"
      ],
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1YbpbR1q6asdqS_ZfwuKl1FM__OLiGtXE",
          "timestamp": 1665077635449
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
